syntax = "proto3";

package tensorflow.tensorforest;

import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto";

// Leaf models specify what is returned at inference time, and how it is
// stored in the decision_trees.Leaf protos.
//
// MODEL_DENSE_CLASSIFICATION is the zero value, so under proto3 it is what a
// LeafModelType field reads as when it has not been set.
enum LeafModelType {
  MODEL_DENSE_CLASSIFICATION = 0;
  MODEL_SPARSE_CLASSIFICATION = 1;
  MODEL_REGRESSION = 2;
  MODEL_SPARSE_OR_DENSE_CLASSIFICATION = 3;
}

// Stats models generally specify information that is collected which is
// necessary to choose a split at a node. Specifically, they operate on
// a SplitCandidate::LeafStat proto.
//
// STATS_DENSE_GINI is the zero value and therefore the proto3 default for an
// unset StatsModelType field.
enum StatsModelType {
  STATS_DENSE_GINI = 0;
  STATS_SPARSE_GINI = 1;
  STATS_LEAST_SQUARES_REGRESSION = 2;
  // Deprecated and no longer supported. The value stays declared so that the
  // number 3 and this name cannot be reused with a different meaning; the
  // option makes code generators flag any remaining use.
  STATS_SPARSE_THEN_DENSE_GINI = 3 [deprecated = true];
  STATS_FIXED_SIZE_SPARSE_GINI = 4;
}

// Allows selection of operations on the collection of split candidates.
// Basic infers right split stats from the leaf stats and each candidate's
// left stats.
//
// COLLECTION_BASIC is the zero value and therefore the proto3 default.
enum SplitCollectionType {
  COLLECTION_BASIC = 0;
  // NOTE(review): this name does not carry the COLLECTION_ prefix used by
  // COLLECTION_BASIC. It is left as-is because renaming an enum value changes
  // the generated identifiers and the text/JSON representation.
  GRAPH_RUNNER_COLLECTION = 1;
}

// Pruning strategies define how candidates are pruned over time.
// SPLIT_PRUNE_HALF prunes the worst half of splits every prune_every_samples,
// etc.  Note that prune_every_samples plays against the depth-dependent
// split_after_samples, so they should be set together.
//
// SPLIT_PRUNE_NONE is the zero value and therefore the proto3 default.
enum SplitPruningStrategyType {
  SPLIT_PRUNE_NONE = 0;
  SPLIT_PRUNE_HALF = 1;
  SPLIT_PRUNE_QUARTER = 2;
  SPLIT_PRUNE_10_PERCENT = 3;
  // SPLIT_PRUNE_HOEFFDING prunes splits whose Gini impurity is worse than
  // the best split's by more than the Hoeffding bound.
  SPLIT_PRUNE_HOEFFDING = 4;
}

// Pairs a pruning strategy with the interval at which it is applied.
message SplitPruningConfig {
  // Interval, in samples, between pruning passes. Expressed as a
  // DepthDependentParam, so it may vary with node depth.
  DepthDependentParam prune_every_samples = 1;
  // Which pruning strategy to apply. Reads as SPLIT_PRUNE_NONE when unset.
  SplitPruningStrategyType type = 2;
}

// Finish strategies define when slots are considered finished.
// Basic requires at least split_after_samples, and doesn't allow slots to
// finish until the leaf has received more than one class. Hoeffding splits
// early after min_split_samples if one split is dominating the rest according
// to Hoeffding bounds. Bootstrap does the same but compares Gini values
// calculated with sampled smoothed counts.
//
// SPLIT_FINISH_BASIC is the zero value and therefore the proto3 default.
//
// NOTE(review): number 1 is not assigned. Presumably it belonged to a value
// that was removed; confirm before giving 1 to any new value, and consider
// reserving it.
enum SplitFinishStrategyType {
  SPLIT_FINISH_BASIC = 0;
  SPLIT_FINISH_DOMINATE_HOEFFDING = 2;
  SPLIT_FINISH_DOMINATE_BOOTSTRAP = 3;
}

// Pairs a finish strategy with how often the finish check is evaluated.
message SplitFinishConfig {
  // Configure how often we check for finish, because some finish methods
  // are expensive to perform. Expressed as a DepthDependentParam, so the
  // interval may vary with node depth.
  DepthDependentParam check_every_steps = 1;
  // Which finish strategy to apply. Reads as SPLIT_FINISH_BASIC when unset.
  SplitFinishStrategyType type = 2;
}

// A parameter that changes linearly with depth, with upper and lower bounds.
//
// NOTE(review): presumably evaluated as
//     clamp(slope * depth + y_intercept, min_val, max_val)
// -- confirm against the code that evaluates this message.
message LinearParam {
  // Change in the parameter per unit of depth.
  float slope = 1;
  // Value of the linear term at depth 0.
  float y_intercept = 2;
  // Lower bound on the parameter.
  float min_val = 3;
  // Upper bound on the parameter.
  float max_val = 4;
}

// A parameter that changes exponentially with depth, with the form
//     f = c + m * b^(k*d)
// where:
//  c: constant bias
//  b: base
//  m: multiplier
//  k: depth multiplier
//  d: depth
message ExponentialParam {
  // c in the formula above.
  float bias = 1;
  // b in the formula above.
  float base = 2;
  // m in the formula above.
  float multiplier = 3;
  // k in the formula above.
  float depth_multiplier = 4;
}

// A parameter that is 'off' until depth >= a threshold, then is 'on'.
message ThresholdParam {
  // Value of the parameter once depth >= threshold.
  float on_value = 1;
  // Value of the parameter while depth < threshold.
  float off_value = 2;
  // Depth at which the parameter switches from off_value to on_value.
  float threshold = 3;
}

// A parameter whose value may vary with node depth.
//
// The alternatives live in a oneof, so at most one of them is set at a time;
// assigning one clears whichever was set before.
message DepthDependentParam {
  oneof ParamType {
    // The same value at every depth.
    float constant_value = 1;
    // Linear in depth, bounded above and below; see LinearParam.
    LinearParam linear = 2;
    // Exponential in depth; see ExponentialParam.
    ExponentialParam exponential = 3;
    // Switches from an 'off' value to an 'on' value at a depth threshold;
    // see ThresholdParam.
    ThresholdParam threshold = 4;
  }
}

// Bundles the model-type selections, the pruning/finish configuration and the
// scalar settings declared below.
//
// Field numbers do not follow declaration order (for example num_features = 21
// sits next to max_nodes = 7). The numbers are what identify fields on the
// wire, so they must not be renumbered to tidy this up.
message TensorForestParams {
  // ------------ Types that control training subsystems ------ //
  // Leaf model. Reads as MODEL_DENSE_CLASSIFICATION when unset.
  LeafModelType leaf_type = 1;
  // Stats model. Reads as STATS_DENSE_GINI when unset.
  StatsModelType stats_type = 2;
  // Split collection implementation. Reads as COLLECTION_BASIC when unset.
  SplitCollectionType collection_type = 3;
  // Despite the "_type" suffix this is the whole SplitPruningConfig (interval
  // plus strategy), not just the strategy enum.
  SplitPruningConfig pruning_type = 4;
  // Despite the "_type" suffix this is the whole SplitFinishConfig (check
  // interval plus strategy), not just the strategy enum.
  SplitFinishConfig finish_type = 5;

  // --------- Parameters that can't change by definition --------------- //
  int32 num_trees = 6;
  int32 max_nodes = 7;
  int32 num_features = 21;

  // The enum type is resolved through the imported
  // decision_trees/proto/generic_tree_model.proto.
  decision_trees.InequalityTest.Type inequality_test_type = 19;

  // Some booleans controlling execution. Each reads as false when unset.
  // NOTE(review): is_regression presumably needs to agree with
  // leaf_type / stats_type selecting a regression model -- confirm.
  bool is_regression = 8;
  bool drop_final_class = 9;
  bool collate_examples = 10;
  bool checkpoint_stats = 11;
  bool use_running_stats_method = 20;
  bool initialize_average_splits = 22;
  bool inference_tree_paths = 23;

  // Number of classes (classification) or targets (regression)
  int32 num_outputs = 12;

  // --------- Parameters that could be depth-dependent --------------- //
  DepthDependentParam num_splits_to_consider = 13;
  // Minimum sample count required by SPLIT_FINISH_BASIC (see
  // SplitFinishStrategyType). Interacts with
  // SplitPruningConfig.prune_every_samples, so set the two together.
  DepthDependentParam split_after_samples = 14;
  // NOTE(review): presumably consumed by the SPLIT_FINISH_DOMINATE_*
  // strategies -- confirm.
  DepthDependentParam dominate_fraction = 15;
  // Sample count after which SPLIT_FINISH_DOMINATE_HOEFFDING may split early
  // (see SplitFinishStrategyType).
  DepthDependentParam min_split_samples = 18;

  // --------- Parameters for experimental features ---------------------- //
  // NOTE(review): presumably used together with GRAPH_RUNNER_COLLECTION --
  // confirm.
  string graph_dir = 16;
  int32 num_select_features = 17;

  // When using a FixedSizeSparseClassificationGrowStats, keep track of
  // this many classes.
  // NOTE(review): presumably pairs with STATS_FIXED_SIZE_SPARSE_GINI --
  // confirm.
  int32 num_classes_to_track = 24;
}
